import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def predict(messages, model, tokenizer, max_new_tokens=2048):
    """Generate the model's reply for a chat-style conversation.

    Args:
        messages: Conversation handed to ``tokenizer.apply_chat_template``
            (callers in this file pass a list of ``{"role", "content"}``
            dicts).
        model: Object exposing ``generate`` and either a ``device``
            attribute or ``parameters()``.
        tokenizer: Object exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        max_new_tokens: Upper bound on the number of newly generated tokens.
            Defaults to 2048, the value that was previously hard-coded.

    Returns:
        The decoded text of the first generated sequence, with the prompt
        tokens stripped and special tokens skipped.
    """
    # Send the inputs to wherever the model's weights actually live instead
    # of probing for mps/cuda/cpu independently: an independently guessed
    # device can disagree with the model's placement and make generate()
    # fail with a device mismatch.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # Forward the attention mask when the tokenizer produced one, so that
    # generate() does not have to infer it from the token ids.
    generate_kwargs = {"max_new_tokens": max_new_tokens}
    attention_mask = getattr(model_inputs, "attention_mask", None)
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask

    generated_ids = model.generate(model_inputs.input_ids, **generate_kwargs)

    # generate() returns prompt + continuation; keep only the continuation.
    completions = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]


# Local checkpoint directory of the base chat model; shared by the tokenizer
# and the model so the two can never point at different checkpoints.
BASE_MODEL_PATH = "./AI4Chem/ChemLLM-7B-Chat"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH, use_fast=False, trust_remote_code=True)
# trust_remote_code must match the tokenizer's setting: a checkpoint that
# ships custom code needs the flag on the model load as well, otherwise
# from_pretrained asks for interactive confirmation or raises.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Wrap the base model with the fine-tuned adapter weights from this checkpoint.
model = PeftModel.from_pretrained(model, model_id="./output/CatalystLLM/checkpoint-1082")

# Sample query: the system instruction carries the persona and the question,
# the user input carries the catalyst being asked about.
test_texts = dict(
    instruction=(
        "You are a helpful chemistry expert with extensive knowledge of catalysis. "
        "You will provide advanced insights and analysis on structural control methods for catalysts. "
        "Please refrain from including any disclaimers or annotations regarding your catalytic knowledge. \n"
        " What is the significance of using Cu steps on reconstructed surfaces for CO2 reduction, "
        "and how does this structural regulation enhance catalytic activity?"
    ),
    input="a Cu electrocatalyst with abundant step sites.",
)

instruction, input_value = test_texts['instruction'], test_texts['input']

# Build the two-turn conversation (system prompt, then user turn).
messages = [
    {"role": role, "content": content}
    for role, content in (("system", instruction), ("user", input_value))
]

# Run generation once and show the decoded reply.
response = predict(messages, model, tokenizer)
print(response)